{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SmoothQuant on Llama 2 7B\n",
    "\n",
    "In this notebook, we use Llama-2-7B model to demonstrate SmoothQuant can use 8-bit for both weights and activations to achieve the similar perplexity as FP16 models."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to run this notebook, you need to install the following packages:\n",
    "\n",
    "- smoothquant\n",
    "- PyTorch\n",
    "- Transformers\n",
    "- Accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from transformers.models.llama.modeling_llama import (\n",
    "    LlamaAttention,\n",
    "    LlamaDecoderLayer,\n",
    "    LlamaForCausalLM,\n",
    "    LlamaMLP,\n",
    ")\n",
    "from transformers import LlamaTokenizer\n",
    "from smoothquant.smooth import smooth_lm\n",
    "from smoothquant.fake_quant import quantize_llama_like\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following is an evaluator to see the performance of the model. We use a toy dataset (the first 40 examples in the test set of the Wikitext-2 dataset) to evaluate the model. You can replace it with your own dataset. The conclusion should be the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Evaluator:\n",
    "    def __init__(self, dataset, tokenizer, device, n_samples=40):\n",
    "        self.dataset = dataset\n",
    "        self.tokenizer = tokenizer\n",
    "        self.device = device\n",
    "\n",
    "        self.dataset = tokenizer(\n",
    "            \"\\n\\n\".join(dataset[\"text\"]), return_tensors=\"pt\"\n",
    "        ).input_ids.to(device)\n",
    "\n",
    "        self.n_samples = n_samples\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def evaluate(self, model):\n",
    "        model.eval()\n",
    "        nlls = []\n",
    "        for i in tqdm.tqdm(range(self.n_samples), desc=\"Evaluating...\"):\n",
    "            batch = self.dataset[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)\n",
    "            with torch.no_grad():\n",
    "                lm_logits = model(batch).logits\n",
    "            shift_logits = lm_logits[:, :-1, :].contiguous().float()\n",
    "            shift_labels = self.dataset[:, (i * 2048) : ((i + 1) * 2048)][:, 1:]\n",
    "            loss_fct = nn.CrossEntropyLoss()\n",
    "            loss = loss_fct(\n",
    "                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)\n",
    "            )\n",
    "            neg_log_likelihood = loss.float() * 2048\n",
    "            nlls.append(neg_log_likelihood)\n",
    "\n",
    "        return torch.exp(torch.stack(nlls).sum() / (self.n_samples * 2048))"
   ]
  },
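  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reading aid (not part of the original evaluation code), the quantity returned by `evaluate` can be written as follows, where $n$ is `n_samples`, $T = 2048$ is the chunk length, and $\\ell_i$ is the mean cross-entropy loss on chunk $i$:\n",
    "\n",
    "$$\\mathrm{PPL} = \\exp\\left(\\frac{\\sum_{i=1}^{n} T\\,\\ell_i}{n\\,T}\\right) = \\exp\\left(\\frac{1}{n}\\sum_{i=1}^{n} \\ell_i\\right)$$\n",
    "\n",
    "Lower perplexity means the model assigns higher probability to the reference text."
   ]
  },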
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/NFS/home/guangxuan/anaconda3/envs/ellm/lib/python3.8/site-packages/pandas/core/computation/expressions.py:20: UserWarning: Pandas requires version '2.7.3' or newer of 'numexpr' (version '2.7.2' currently installed).\n",
      "  from pandas.core.computation.check import NUMEXPR_INSTALLED\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "tokenizer = LlamaTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n",
    "dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')\n",
    "evaluator = Evaluator(dataset, tokenizer, \"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## FP16 Model Perplexity"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first check the performance of the original FP16 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27d329da8d934f258323e1977002cebd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model_fp16 = LlamaForCausalLM.from_pretrained(\n",
    "    \"meta-llama/Llama-2-7b-hf\", torch_dtype=torch.float16, device_map=\"auto\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating...: 100%|██████████| 40/40 [00:09<00:00,  4.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original model (fp16) perplexity: 5.822948932647705\n"
     ]
    }
   ],
   "source": [
    "ppl_fp16 = evaluator.evaluate(model_fp16)\n",
    "print(f\"Original model (fp16) perplexity: {ppl_fp16}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then quantize the model to W8A8 and check the performance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Naive W8A8 Quantized Model Perplexity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LlamaForCausalLM(\n",
      "  (model): LlamaModel(\n",
      "    (embed_tokens): Embedding(32000, 4096)\n",
      "    (layers): ModuleList(\n",
      "      (0-31): 32 x LlamaDecoderLayer(\n",
      "        (self_attn): LlamaSdpaAttention(\n",
      "          (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (rotary_emb): LlamaRotaryEmbedding()\n",
      "        )\n",
      "        (mlp): LlamaMLP(\n",
      "          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (down_proj): W8A8Linear(11008, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (act_fn): SiLU()\n",
      "        )\n",
      "        (input_layernorm): LlamaRMSNorm()\n",
      "        (post_attention_layernorm): LlamaRMSNorm()\n",
      "      )\n",
      "    )\n",
      "    (norm): LlamaRMSNorm()\n",
      "  )\n",
      "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model_w8a8 = quantize_llama_like(model_fp16)\n",
    "print(model_w8a8)"
   ]
  },
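  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The printout above shows `weight_quant=per_channel` and `act_quant=per_token`. The cell below is a minimal sketch of what symmetric 8-bit fake quantization along those axes could look like. It is an illustration under stated assumptions, not the `smoothquant` library's implementation: the helper `fake_quant_absmax` is a hypothetical name introduced here, and the tensors are random toy data rather than real Llama weights or activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def fake_quant_absmax(t, dim, n_bits=8):\n",
    "    # Hypothetical helper: symmetric absmax quantization along `dim`,\n",
    "    # followed by immediate dequantization (fake quantization).\n",
    "    q_max = 2 ** (n_bits - 1) - 1  # 127 for 8 bits\n",
    "    scale = t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-5) / q_max\n",
    "    return (t / scale).round().clamp(-q_max, q_max) * scale\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x = torch.randn(4, 16)  # toy activations: 4 tokens, 16 input channels\n",
    "w = torch.randn(8, 16)  # toy weights: 8 output channels, 16 input channels\n",
    "\n",
    "x_q = fake_quant_absmax(x, dim=-1)  # per-token: one scale per row of x\n",
    "w_q = fake_quant_absmax(w, dim=-1)  # per-channel: one scale per output channel\n",
    "\n",
    "y_ref = x @ w.T\n",
    "y_q = x_q @ w_q.T\n",
    "print(f\"max abs output error: {(y_ref - y_q).abs().max().item():.4f}\")"
   ]
  },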
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating...: 100%|██████████| 40/40 [00:07<00:00,  5.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Naive W8A8 quantized model perplexity: 5.931240558624268\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "ppl_w8a8 = evaluator.evaluate(model_w8a8)\n",
    "print(f\"Naive W8A8 quantized model perplexity: {ppl_w8a8}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see there is a perplexity increase. We then use SmoothQuant to quantize the model and check the performance."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SmoothQuant W8A8 Quantized Model Perplexity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "321b7c77ef984e6f82c74c8b6b16e558",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LlamaForCausalLM(\n",
      "  (model): LlamaModel(\n",
      "    (embed_tokens): Embedding(32000, 4096)\n",
      "    (layers): ModuleList(\n",
      "      (0-31): 32 x LlamaDecoderLayer(\n",
      "        (self_attn): LlamaSdpaAttention(\n",
      "          (q_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (k_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (v_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (o_proj): W8A8Linear(4096, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (rotary_emb): LlamaRotaryEmbedding()\n",
      "        )\n",
      "        (mlp): LlamaMLP(\n",
      "          (gate_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (up_proj): W8A8Linear(4096, 11008, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (down_proj): W8A8Linear(11008, 4096, bias=False, weight_quant=per_channel, act_quant=per_token, output_quant=None)\n",
      "          (act_fn): SiLU()\n",
      "        )\n",
      "        (input_layernorm): LlamaRMSNorm()\n",
      "        (post_attention_layernorm): LlamaRMSNorm()\n",
      "      )\n",
      "    )\n",
      "    (norm): LlamaRMSNorm()\n",
      "  )\n",
      "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "model = LlamaForCausalLM.from_pretrained(\n",
    "    \"meta-llama/Llama-2-7b-hf\", torch_dtype=torch.float16, device_map=\"auto\"\n",
    ")\n",
    "act_scales = torch.load(\"../act_scales/llama-2-7b.pt\")\n",
    "smooth_lm(model, act_scales, 0.85)\n",
    "model_smoothquant_w8a8 = quantize_llama_like(model)\n",
    "print(model_smoothquant_w8a8)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the smoothed model has a lower perplexity which is close to the FP16 model's. This is because SmoothQuant smooths the outliers in activations and balances the quantization difficulty of activations and weights."
   ]
  },
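  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the smoothing described above can be summarized as follows, using the notation of the SmoothQuant paper: $X$ is the input activation of a linear layer, $W$ its weight, and $j$ indexes input channels. This is a summary of the method and not a description of the exact code path in `smooth_lm`:\n",
    "\n",
    "$$s_j = \\frac{\\max(|X_j|)^{\\alpha}}{\\max(|W_j|)^{1-\\alpha}}, \\qquad Y = \\left(X\\,\\mathrm{diag}(s)^{-1}\\right)\\left(\\mathrm{diag}(s)\\,W\\right) = XW$$\n",
    "\n",
    "Here $\\alpha$ is the migration strength (0.85 in this notebook). Dividing the activations by $s$ shrinks the outlier channels, and multiplying the weights by $s$ keeps the layer output mathematically unchanged."
   ]
  },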
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluating...: 100%|██████████| 40/40 [00:08<00:00,  4.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SmoothQuant W8A8 quantized model perplexity: 5.85634183883667\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "ppl_smoothquant_w8a8 = evaluator.evaluate(model_smoothquant_w8a8)\n",
    "print(f\"SmoothQuant W8A8 quantized model perplexity: {ppl_smoothquant_w8a8}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 (conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "c458cb81aeeb610631c72e4cc4799f00f630d4dfa7a554b37f8134a7fe160cb8"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
